import numpy as np
import pandas as pd
import plotly.graph_objects as go
import statsmodels.api as sm
from plotly.subplots import make_subplots
from scipy import stats

table = pd.read_csv('Supplemental Table2.csv')

# Regressors for the OLS model. 'const' is taken from the CSV rather than added
# with sm.add_constant, so the file presumably carries an intercept column — confirm.
exog_cols = ['const', 'anger', 'disgust', 'fear', 'sadness', 'identity',
             'protests', 'arrests', 'rounds',
             'protests_sq', 'arrests_sq', 'rounds_sq',
             'protests_identity', 'arrests_identity', 'rounds_identity']

# fit the OLS model
model = sm.OLS(endog=table['radical'], exog=table[exog_cols]).fit()

# show model results; a bare `model.summary()` expression is only displayed in a
# notebook, so print it explicitly when run as a script
print(model.summary())

# Persist the fitted results. NOTE(review): statsmodels' save() writes a pickle,
# not HDF5, despite the .h5 suffix.
model.save('Supplemental Table2.h5')

#-------------------------------------------------------------------------------------------#
#-------------------------------------------------------------------------------------------#

# Marginal Effect Plots
# Residuals of the fitted model and their sum of squares (equivalent to model.ssr).
y_err = table['radical'].values - model.predict()
s_err = np.sum(np.power(y_err, 2))

params = model.params  # fitted coefficients, indexed by regressor name

# Per-variable results keyed by 'protests' / 'arrests' / 'rounds'; values are lists
# so the plotting sections can pass them straight to go.Scatter.
p_y = {}
p_x = {}
lower_90 = {}
lower_95 = {}
upper_90 = {}
upper_95 = {}
# Same quantities for the curve that also adds the '<var>_identity' coefficient.
p_y_interact = {}
lower_90_interact = {}
lower_95_interact = {}
upper_90_interact = {}
upper_95_interact = {}

step = 0.01  # spacing of the x-grid at which effects are evaluated

for i in ['protests', 'arrests', 'rounds']:
    x = table[i].to_numpy(dtype=float)

    # Evaluation grid over the observed range. np.arange excludes its stop value,
    # so add half a step to keep max(x) on the grid.
    x_grid = np.arange(round(x.min(), 2), round(x.max(), 2) + step / 2, step)
    print(i, round(x.max(), 2))

    mean_x = np.mean(x)
    n = len(x)  # number of samples in the original model

    # Two-sided (1 - alpha) intervals need the 1 - alpha/2 t quantile.
    tstat_90 = stats.t.ppf(0.95, n - 1)
    tstat_95 = stats.t.ppf(0.975, n - 1)

    # NOTE(review): this is the simple-regression CI for a mean response
    # (residual variance over n-2 df, spread driven only by this x). It ignores
    # the other regressors and the covariance between params[i] and
    # params[i + '_sq'], and the t quantile uses n-1 df rather than
    # model.df_resid. A delta-method band from model.cov_params() is the usual
    # choice for a multi-term marginal effect; confirm which is intended.
    sxx = np.sum(np.power(x - mean_x, 2))  # == sum(x**2) - n*mean_x**2, but stabler
    se = np.sqrt((s_err / (n - 2)) * (1.0 / n + np.power(x_grid - mean_x, 2) / sxx))
    confs_90 = tstat_90 * se
    confs_95 = tstat_95 * se

    # Effect of the variable alone (linear + quadratic term) and with the
    # '<var>_identity' slope added on top.
    curve = params[i] * x_grid + params[i + '_sq'] * x_grid ** 2
    curve_interact = curve + params[i + '_identity'] * x_grid

    p_x[i] = list(x_grid)

    p_y[i] = list(curve)
    lower_90[i] = list(curve - confs_90)
    upper_90[i] = list(curve + confs_90)
    lower_95[i] = list(curve - confs_95)
    upper_95[i] = list(curve + confs_95)

    # NOTE(review): the interaction curve reuses the same band half-widths as
    # the non-interaction curve; the identity coefficient's uncertainty is not
    # reflected.
    p_y_interact[i] = list(curve_interact)
    lower_90_interact[i] = list(curve_interact - confs_90)
    upper_90_interact[i] = list(curve_interact + confs_90)
    lower_95_interact[i] = list(curve_interact - confs_95)
    upper_95_interact[i] = list(curve_interact + confs_95)

#-------------------------------------------------------------------------------------------#
#-------------------------------------------------------------------------------------------#

# Plot the marginal effect graph without interaction
# One column per variable; the three panels share the y-axis.
fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=['<i><b>Rounds</b></i>', '<i><b>Arrests</b></i>', '<i><b>Protest Size</b></i>'],
    vertical_spacing=0.0002,
    shared_yaxes=True,
    print_grid=True)

for k, i in enumerate(['rounds', 'arrests', 'protests'], start=1):  # k = subplot column
    # zero line
    fig.add_trace(go.Scatter(x=p_x[i], y=[0] * len(p_x[i]), mode='lines',
                             line=dict(color='grey', width=0.5),
                             name='Zero Line', showlegend=False), row=1, col=k)

    # Vertical reference line at the grid point with the second-smallest |effect|.
    # Presumably index 0 is the trivial zero of the curve at x = 0 and the next
    # point approximates its other zero crossing; confirm.
    x_ref = p_x[i][np.argsort(np.abs(np.array(p_y[i])))[1]]
    fig.add_trace(go.Scatter(x=[x_ref, x_ref], y=[-0.05, 0.12], mode='lines',
                             line=dict(color='black', width=1, dash='dot'),
                             showlegend=False), row=1, col=k)

    # CI bands. Trace order matters: fill='tonexty' fills against the trace
    # added immediately before.
    fig.add_trace(go.Scatter(x=p_x[i], y=lower_90[i], mode='lines',
                             line=dict(color='black', width=0.2),
                             name='Lower confidence limit (90%)', showlegend=False), row=1, col=k)
    fig.add_trace(go.Scatter(x=p_x[i], y=lower_95[i], mode='lines',
                             line=dict(color='black', width=0.2), fill='tonexty',
                             name='Lower confidence limit (95%)', showlegend=False), row=1, col=k)
    fig.add_trace(go.Scatter(x=p_x[i], y=upper_95[i], mode='lines',
                             line=dict(color='black', width=0.2), fill='tonexty',
                             name='Upper confidence limit (95%)', showlegend=False), row=1, col=k)
    fig.add_trace(go.Scatter(x=p_x[i], y=upper_90[i], mode='lines',
                             line=dict(color='black', width=0.2), fill='tonexty',
                             name='Upper confidence limit (90%)', showlegend=False), row=1, col=k)

    # regression line
    fig.add_trace(go.Scatter(x=p_x[i], y=p_y[i], mode='lines',
                             line=dict(color='black', width=1),
                             name='Regression Line', showlegend=False), row=1, col=k)

# style the figure
fig.update_layout(width=900, height=400, template='plotly_white', showlegend=True,
                  legend=dict(orientation='h'), margin=dict(l=5, r=5, t=20, b=5))
fig.update_xaxes(showline=True, linecolor='black', linewidth=1, ticks='outside', mirror=True)
fig.update_yaxes(range=[-0.02, 0.023], showline=True, linecolor='black', linewidth=1,
                 ticks='outside', mirror='all')
# title.font replaces the deprecated `titlefont` attribute
fig.update_yaxes(title=dict(text='<b><i>First-Difference Model</i></b>', font=dict(size=16)),
                 row=1, col=1)
# x-values are plotted as exponents; label ticks as powers of ten
fig.update_xaxes(tickmode='array', tickvals=[0, 2, 4, 6],
                 ticktext=['10<sup>0</sup>', '10<sup>2</sup>', '10<sup>4</sup>', '10<sup>6</sup>'],
                 row=1, col=3)
fig.update_xaxes(tickmode='array', tickvals=[0, 1, 2, 3, 4],
                 ticktext=['10<sup>0</sup>', '10<sup>1</sup>', '10<sup>2</sup>', '10<sup>3</sup>', '10<sup>4</sup>'],
                 row=1, col=2)
fig.update_xaxes(tickmode='array', tickvals=[0, 1, 2, 3, 4],
                 ticktext=['10<sup>0</sup>', '10<sup>1</sup>', '10<sup>2</sup>', '10<sup>3</sup>', '10<sup>4</sup>'],
                 row=1, col=1)

# Plotly figures have no .save(); static export goes through write_image
# (requires the kaleido package).
fig.write_image('Supplemental Figure S2.jpg')

#-------------------------------------------------------------------------------------------#
#-------------------------------------------------------------------------------------------#

# Plot the marginal effect graph with interaction
# One column per variable; the three panels share the y-axis.
fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=['<i><b>Rounds</b></i>', '<i><b>Arrests</b></i>', '<i><b>Protest Size</b></i>'],
    vertical_spacing=0.0002,
    shared_yaxes=True,
    print_grid=True)

for k, i in enumerate(['rounds', 'arrests', 'protests'], start=1):  # k = subplot column
    # zero line
    fig.add_trace(go.Scatter(x=p_x[i], y=[0] * len(p_x[i]), mode='lines',
                             line=dict(color='grey', width=0.5),
                             name='Zero Line', showlegend=False), row=1, col=k)

    # Vertical reference line at the grid point with the second-smallest |effect|
    # of the interaction curve (index 0 presumably being the trivial zero at x = 0;
    # confirm).
    x_ref = p_x[i][np.argsort(np.abs(np.array(p_y_interact[i])))[1]]
    fig.add_trace(go.Scatter(x=[x_ref, x_ref], y=[-0.05, 0.12], mode='lines',
                             line=dict(color='black', width=1, dash='dot'),
                             showlegend=False), row=1, col=k)

    # CI bands. Trace order matters: fill='tonexty' fills against the trace
    # added immediately before.
    fig.add_trace(go.Scatter(x=p_x[i], y=lower_90_interact[i], mode='lines',
                             line=dict(color='black', width=0.2),
                             name='Lower confidence limit (90%)', showlegend=False), row=1, col=k)
    fig.add_trace(go.Scatter(x=p_x[i], y=lower_95_interact[i], mode='lines',
                             line=dict(color='black', width=0.2), fill='tonexty',
                             name='Lower confidence limit (95%)', showlegend=False), row=1, col=k)
    fig.add_trace(go.Scatter(x=p_x[i], y=upper_95_interact[i], mode='lines',
                             line=dict(color='black', width=0.2), fill='tonexty',
                             name='Upper confidence limit (95%)', showlegend=False), row=1, col=k)
    fig.add_trace(go.Scatter(x=p_x[i], y=upper_90_interact[i], mode='lines',
                             line=dict(color='black', width=0.2), fill='tonexty',
                             name='Upper confidence limit (90%)', showlegend=False), row=1, col=k)

    # regression line
    fig.add_trace(go.Scatter(x=p_x[i], y=p_y_interact[i], mode='lines',
                             line=dict(color='black', width=1),
                             name='Regression Line', showlegend=False), row=1, col=k)

# style the figure
fig.update_layout(width=900, height=400, template='plotly_white', showlegend=True,
                  legend=dict(orientation='h'), margin=dict(l=5, r=5, t=20, b=5))
fig.update_xaxes(showline=True, linecolor='black', linewidth=1, ticks='outside', mirror=True)
fig.update_yaxes(range=[-0.02, 0.023], showline=True, linecolor='black', linewidth=1,
                 ticks='outside', mirror='all')
# title.font replaces the deprecated `titlefont` attribute
fig.update_yaxes(title=dict(text='<b><i>First-Difference Model</i></b>', font=dict(size=16)),
                 row=1, col=1)
# x-values are plotted as exponents; label ticks as powers of ten
fig.update_xaxes(tickmode='array', tickvals=[0, 2, 4, 6],
                 ticktext=['10<sup>0</sup>', '10<sup>2</sup>', '10<sup>4</sup>', '10<sup>6</sup>'],
                 row=1, col=3)
fig.update_xaxes(tickmode='array', tickvals=[0, 1, 2, 3, 4],
                 ticktext=['10<sup>0</sup>', '10<sup>1</sup>', '10<sup>2</sup>', '10<sup>3</sup>', '10<sup>4</sup>'],
                 row=1, col=2)
fig.update_xaxes(tickmode='array', tickvals=[0, 1, 2, 3, 4],
                 ticktext=['10<sup>0</sup>', '10<sup>1</sup>', '10<sup>2</sup>', '10<sup>3</sup>', '10<sup>4</sup>'],
                 row=1, col=1)

# Plotly figures have no .save(); static export goes through write_image
# (requires the kaleido package).
fig.write_image('Supplemental Figure S2_interaction.jpg')